package com.senlin;


import java.io.IOException;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Fluent entry point for configuring and running downloads.
 *
 * @Author: toor
 * @Date: Created in 11:23 2021/4/1
 */
public class Thunder extends Fetch {

    /** Creates a new, unconfigured instance. */
    public static Thunder me() {
        return new Thunder();
    }

    /**
     * Splits each download into roughly {@code num} shards.
     *
     * <p>The shard-size function yields {@code null} (no preference) when {@code num} is not
     * positive or when the resulting shard would be smaller than 1 KB.
     *
     * @param num desired number of shards per download
     * @return this instance, for chaining
     */
    public Thunder sharingNum(long num) {
        super.sharingUnitFunc(size -> {
            // num == 0 must be rejected before dividing, otherwise size / num throws.
            if (num <= 0) {
                return null;
            }
            final long unit = size / num;
            // Shards must be at least 1 KB.
            return unit < 1024 ? null : unit;
        });
        return this;
    }

    /**
     * Uses a fixed shard size expressed in kilobytes.
     *
     * @param unit shard size in KB; non-positive values yield {@code null} (no preference)
     * @return this instance, for chaining
     */
    public Thunder sharingUnitByKB(long unit) {
        super.sharingUnitFunc(size -> {
            if (unit <= 0) {
                return null;
            }
            // Convert KB to bytes, saturating instead of overflowing for absurdly large inputs.
            return unit > Long.MAX_VALUE / 1024 ? Long.MAX_VALUE : unit * 1024;
        });
        return this;
    }
}

/**
 * Callbacks through which the download engine delegates partitioning, storage and reporting.
 */
interface DownloadHandler {

    /**
     * Splits one download into shard tasks.
     *
     * @param id      number of the download being split
     * @param size    size of the download in bytes (the engine passes the Content-Length header, or 0 when absent)
     * @param isRange whether the server accepts byte-range requests
     * @param address URL the shards should be fetched from
     * @return the shard tasks to enqueue; an empty list means nothing is enqueued
     */
    List<ShardingTask> partition(long id, long size, boolean isRange, String address);

    /**
     * Consumes the body of one shard response.
     *
     * @param data         response body of the shard request
     * @param downloadTask download the shard belongs to
     * @param shardingTask shard whose bytes are in {@code data}
     * @throws Exception any failure; the engine may re-queue the shard for another attempt
     */
    void data(InputStream data, DownloadTask downloadTask, ShardingTask shardingTask) throws Exception;

    /**
     * Invoked once the last outstanding shard of a download has been processed.
     *
     * @param task the download that has finished
     */
    void completed(DownloadTask task);

    /**
     * Invoked whenever a task fails, whether or not it is going to be retried.
     *
     * @param e    the failure cause
     * @param task the failing task (a download probe or a shard)
     */
    void failure(Throwable e, Task task);
}

/**
 * A unit of work that the download engine turns into one HTTP request.
 */
interface Task {

    /** @return number of the download this task belongs to */
    long getId();

    /** @return URL the request is sent to */
    String getAddress();

    /** @return HTTP method used for the request */
    Fetch.HttpMethod getMethod();

    /** @return whether the request should carry a byte-range header */
    boolean isRange();

    /** @return how many times this task has been retried so far */
    long getReties();

    /** Increments the retry counter by one. */
    void incrReties();
}

/**
 * A single file download: the URL to fetch, where to store it, and the counters used while its
 * shards are processed. As a {@link Task} it is issued as a HEAD request.
 */
class DownloadTask implements Task {
    // Download number.
    private final long id;
    // Source URL.
    private final String address;
    // Target path on disk; may be replaced once the response headers are known.
    private String location;
    // Shards still in progress; the engine sets it to -1 once the download is done.
    private final AtomicInteger count;
    // Downloaded file type (exposed as getMine/setMine).
    private String fileType;
    // Downloaded file size.
    private String size;
    // When shard processing started.
    private LocalDateTime createdTime;
    // When the download finished.
    private LocalDateTime completedTime;
    // Retries performed for the HEAD probe.
    private long retryCount = 0;
    // Shards not yet completed; -1 until partitioned. Shared across threads, hence atomic.
    private final AtomicInteger pendingShards;

    public DownloadTask(long id, String address, String location) {
        this.id = id;
        this.address = address;
        this.location = location;
        this.count = new AtomicInteger();
        this.pendingShards = new AtomicInteger(-1);
    }

    /** Static factory equivalent to the constructor. */
    public static DownloadTask create(long id, String address, String location) {
        return new DownloadTask(id, address, location);
    }

    @Override
    public long getReties() {
        return retryCount;
    }

    @Override
    public void incrReties() {
        retryCount += 1;
    }

    @Override
    public long getId() {
        return id;
    }

    @Override
    public String getAddress() {
        return address;
    }

    public String getLocation() {
        return location;
    }

    public void setLocation(String location) {
        this.location = location;
    }

    public AtomicInteger getCount() {
        return count;
    }

    public String getMine() {
        return fileType;
    }

    public void setMine(String mine) {
        this.fileType = mine;
    }

    public String getSize() {
        return size;
    }

    public void setSize(String size) {
        this.size = size;
    }

    public LocalDateTime getCreatedTime() {
        return createdTime;
    }

    public void setCreatedTime(LocalDateTime createdTime) {
        this.createdTime = createdTime;
    }

    public LocalDateTime getCompletedTime() {
        return completedTime;
    }

    public void setCompletedTime(LocalDateTime completedTime) {
        this.completedTime = completedTime;
    }

    public AtomicInteger getUnfinishedTaskNum() {
        return pendingShards;
    }

    @Override
    public String toString() {
        return new StringBuilder("DownloadTask{")
                .append("id=").append(id)
                .append(", address='").append(address).append('\'')
                .append(", location='").append(location).append('\'')
                .append(", count=").append(count)
                .append(", mine='").append(fileType).append('\'')
                .append(", size='").append(size).append('\'')
                .append(", createdTime=").append(createdTime)
                .append(", completedTime=").append(completedTime)
                .append('}')
                .toString();
    }

    /** The download itself is probed with HEAD to discover its size and range support. */
    @Override
    public Fetch.HttpMethod getMethod() {
        return Fetch.HttpMethod.HEAD;
    }

    /** The probe request is never ranged. */
    @Override
    public boolean isRange() {
        return false;
    }

    /** @return {@code true} when the pending-shard counter has reached exactly zero */
    public boolean isSuccessful() {
        return 0 == pendingShards.intValue();
    }
}


/**
 * One slice of a download, fetched with a single (optionally ranged) request.
 *
 * <p>Apart from the retry counter, every field is fixed at construction time.
 */
class ShardingTask implements Task {
    // Number of the download this shard belongs to.
    private final long id;
    // First byte of the slice.
    private final long start;
    // Last byte of the slice; sent as the end of the Range header (HTTP treats it as inclusive).
    private final long end;
    // Position of this shard, starting from 1.
    private final long offset;
    // Total number of shards in the download.
    private final long num;
    // Size of the whole download in bytes.
    private final long total;
    // Request URL.
    private final String address;
    // Request method.
    private final Fetch.HttpMethod method;
    // Retries performed so far. NOTE(review): not synchronized; assumes one thread handles the shard at a time.
    private long reties = 0;

    public ShardingTask(long id, long start, long end, long offset, long num, long total, String address, Fetch.HttpMethod method) {
        this.id = id;
        this.start = start;
        this.end = end;
        this.offset = offset;
        this.num = num;
        this.total = total;
        this.address = address;
        this.method = method;
    }

    @Override
    public long getReties() {
        return reties;
    }

    @Override
    public void incrReties() {
        reties++;
    }

    @Override
    public long getId() {
        return id;
    }

    public long getStart() {
        return start;
    }

    public long getEnd() {
        return end;
    }

    public long getOffset() {
        return offset;
    }

    public long getNum() {
        return num;
    }

    public long getTotal() {
        return total;
    }

    @Override
    public String getAddress() {
        return address;
    }

    @Override
    public Fetch.HttpMethod getMethod() {
        return method;
    }

    /**
     * A shard created with {@code start == end == 0} denotes a whole-file request without a Range
     * header; any other bounds are sent as a byte range.
     */
    @Override
    public boolean isRange() {
        return !(start == 0 && end == 0);
    }

    @Override
    public String toString() {
        return "ShardingTask{" +
                "id=" + id +
                ", start=" + start +
                ", end=" + end +
                ", offset=" + offset +
                ", num=" + num +
                ", total=" + total +
                ", address='" + address + '\'' +
                ", method=" + method +
                '}';
    }
}

/**
 * Default handler: splits downloads into fixed-size byte ranges and writes each range straight
 * into the target file at its own offset.
 */
class DefaultDownloadHandler implements DownloadHandler {

    /**
     * Splits a download into shards of {@link #unit(long)} bytes.
     *
     * <p>Range bounds are inclusive (as in the HTTP {@code Range} header): every shard but the last
     * covers exactly {@code unit} bytes, and the last one absorbs the remainder up to byte
     * {@code size - 1}. Without range support a single whole-file shard is returned.
     *
     * @return the shards, or an empty list when {@code size <= 0}
     */
    @Override
    public List<ShardingTask> partition(long id, long size, boolean isRange, String address) {
        if (size <= 0) {
            return Collections.emptyList();
        }

        if (!isRange) {
            return List.of(new ShardingTask(id, 0, 0, 1, 1, size, address, Fetch.HttpMethod.GET));
        }

        // Guard against subclasses returning 0 or a negative unit (would divide by zero).
        final long unit = Math.max(1L, unit(size));
        // At least one shard, even when the file is smaller than a single unit.
        final long count = Math.max(1L, size / unit);

        final List<ShardingTask> tasks = new ArrayList<>((int) Math.min(count, Integer.MAX_VALUE));

        long start = 0;
        for (long offset = 1; offset < count; offset++) {
            final long end = start + unit - 1; // inclusive last byte of this shard
            tasks.add(new ShardingTask(id, start, end, offset, count, size, address, Fetch.HttpMethod.GET));
            start = end + 1;
        }
        tasks.add(new ShardingTask(id, start, size - 1, count, count, size, address, Fetch.HttpMethod.GET));
        return tasks;
    }

    /**
     * Streams one shard into the target file, starting at the shard's first byte.
     *
     * @throws Exception any I/O failure while reading the body or writing the file
     */
    @Override
    public void data(InputStream inputStream, DownloadTask downloadTask, ShardingTask shardingTask) throws Exception {
        final Path path = Paths.get(downloadTask.getLocation());
        try (final ReadableByteChannel inChannel = Channels.newChannel(inputStream);
             final RandomAccessFile os = new RandomAccessFile(path.toFile(), "rw");
             final FileChannel osChannel = os.getChannel()) {

            long position = shardingTask.getStart();
            final ByteBuffer buffer = ByteBuffer.allocate(8192);
            final long start = System.currentTimeMillis();
            while (inChannel.read(buffer) != -1) {
                buffer.flip();
                // A positional write may be partial: keep writing until the buffer is drained so
                // that `position` always advances by exactly the bytes written.
                while (buffer.hasRemaining()) {
                    position += osChannel.write(buffer, position);
                }
                buffer.clear();
            }
            final long end = System.currentTimeMillis();
            System.out.println("[下载完成@" + shardingTask.getId() + "#" + shardingTask.getOffset() + "] - 耗时:" + (end - start) + "ms - " + shardingTask);
        }
    }

    /** Records the completion time and prints a summary line. */
    @Override
    public void completed(DownloadTask task) {
        task.setCompletedTime(LocalDateTime.now());
        final Duration between = Duration.between(task.getCreatedTime(), task.getCompletedTime());
        System.out.println("------------------------- [下完了@" + Paths.get(task.getLocation()).getFileName().toString() + "] --- 状态：" + task.isSuccessful() + " --- 耗时：" + between.toMillis() +" ms ----------------------------");
    }

    /** Failures are deliberately ignored by default; override to log or alert. */
    @Override
    public void failure(Throwable e, Task task) {
        // intentionally empty
    }

    /**
     * Shard size in bytes for a download of {@code size} bytes.
     *
     * @return 4 MB by default
     */
    protected long unit(long size) {
        return 4L * 1024 * 1024;
    }
}


/**
 * Multi-threaded HTTP download engine.
 *
 * <p>Each scheduled {@link DownloadTask} is probed with a HEAD request. Its response headers are
 * handed to {@link DownloadHandler#partition} to produce {@link ShardingTask}s. Those are fetched
 * with GET requests and streamed to {@link DownloadHandler#data}. A timer polls every 5 seconds
 * and stops the engine once every download task is finished.
 *
 * <p>Protocol of {@link DownloadTask#getCount()}:
 * <ul>
 *   <li>{@code 0} before the HEAD probe succeeds;</li>
 *   <li>the number of outstanding shards while downloading;</li>
 *   <li>{@code -1} ({@link #FINISHED}) once the task is done, successfully or not.</li>
 * </ul>
 *
 * <p>NOTE(review): the queue, task map, thread pool, semaphores and HTTP client are {@code static},
 * so every instance shares them.
 */
abstract class Fetch {
    /** Value of {@link DownloadTask#getCount()} that marks a download task as finished. */
    private static final int FINISHED = -1;

    // TODO: prefer a caller-supplied thread pool.
    protected static int threadNum = (Runtime.getRuntime().availableProcessors() * 2 + 1) * 2;
    protected static ThreadPoolExecutor executorService;
    // Caps the number of in-flight HTTP requests at threadNum.
    protected static Semaphore serviceSemaphore;
    // D1SM mode only: a single permit, so one DownloadTask is processed at a time.
    protected static Semaphore downloadTaskSemaphore;
    protected static ArrayBlockingQueue<Task> taskQueue = new ArrayBlockingQueue<>(100_000);
    protected static ConcurrentHashMap<Long, DownloadTask> downloadMaps = new ConcurrentHashMap<>(10000);
    // Daemon thread, so an instance that is never run() cannot keep the JVM alive.
    protected Timer timer = new Timer("fetch-monitor", true);
    protected long downloadNo = 0;
    // MIME type -> file extension, loaded from mime.txt.
    protected Map<String, String> mimeMaps;
    protected static HttpClient client;
    protected long MAX_RETIES = 3;
    protected DownloadHandler callback;
    // Maps a file size to a shard size in bytes; null or < 1 KB falls back to 4 MB in the default handler.
    protected Function<Long, Long> computeSharingUnitFunc = size -> 4L * 1024 * 1024;
    protected FMode mode = FMode.DMSM;


    /** Creates the thread pool, request semaphore, HTTP client, default handler and MIME table. */
    private void init() {
        serviceSemaphore = new Semaphore(threadNum);
        executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(threadNum);

        client = HttpClient.newBuilder()
                .version(HttpClient.Version.HTTP_1_1)
                .followRedirects(HttpClient.Redirect.NORMAL)
                .connectTimeout(Duration.ofSeconds(60))
                .executor(executorService)
                .build();

        if (callback == null) {
            callback = new DefaultDownloadHandler() {
                @Override
                protected long unit(long size) {
                    // Fall back to 4 MB when the configured function declines or yields less than 1 KB.
                    Long sharingUnitSize = computeSharingUnitFunc.apply(size);
                    return sharingUnitSize == null || sharingUnitSize < 1024 ? 4L * 1024 * 1024 : sharingUnitSize;
                }
            };
        }

        mimeMaps = loadingMime();
    }

    /** Queues the HEAD probe of every scheduled download, in id order. */
    private void initTaskQueue() {
        downloadMaps.values().stream().sorted(Comparator.comparing(DownloadTask::getId)).forEach(taskQueue::add);
    }

    private void initDownloadTaskSemaphore() {
        downloadTaskSemaphore = new Semaphore(1);
    }

    /**
     * Starts a timer that, once every download task is finished, shuts the pool down and
     * interrupts {@code mainThread} so that {@link #run()} can return.
     */
    private void initMonitor(Thread mainThread) {
        timer.schedule(new TimerTask() {
            @Override
            public void run() {
                if (Fetch.this.isCompleted()) {
                    System.out.println("#################### 关闭线程池");
                    executorService.shutdownNow();
                    mainThread.interrupt();
                    timer.cancel();
                }
            }
        }, 1000, 5000);
    }

    /** Runs every scheduled download and blocks the calling thread until all of them are finished. */
    public void run() {
        init();
        initMonitor(Thread.currentThread());
        initTaskQueue();

        switch (mode) {
            case DMSM:
                runModeByDMSM();
                break;

            case D1SM:
                initDownloadTaskSemaphore();
                runModeByD1SM();
                break;
        }

        System.out.println("#################### 已下载：");
        final String stat = downloadMaps.values().stream().map(DownloadTask::getLocation).collect(Collectors.joining("， ", "【", "】"));
        System.out.println(stat);
    }

    /** Feeds queued tasks to {@code runner} until the pool is shut down or this thread is interrupted. */
    private void looper(FRunner runner) {
        while (!executorService.isShutdown() && !Thread.currentThread().isInterrupted()) {
            try {
                runner.run(taskQueue.take());
            } catch (InterruptedException e) {
                System.out.println("#################### 正在退出...");
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                // Best effort: one misbehaving task must not stop the loop.
                // NOTE(review): consider logging this.
            }
        }
    }

    /** Takes a request permit and dispatches {@code task}. */
    private void realRunning(Task task) throws InterruptedException {
        serviceSemaphore.acquire();
        try {
            this.download(task);
        } catch (RuntimeException e) {
            // The request never went out (e.g. malformed address). Return the permit and use the
            // normal retry/failure path instead of silently dropping the task.
            serviceSemaphore.release();
            onFailure(e, task);
        }
    }

    private void runModeByDMSM() {
        looper(this::realRunning);
    }

    private void runModeByD1SM() {
        looper((task) -> {
            if (task instanceof DownloadTask) {
                // TODO: busy-polls while another download holds the slot.
                if (!downloadTaskSemaphore.tryAcquire()) {
                    taskQueue.add(task);
                    return;
                }
            }

            realRunning(task);
        });
    }

    /** @return {@code true} once every scheduled download task is marked {@link #FINISHED} */
    private boolean isCompleted() {
        return downloadMaps.values()
                .stream()
                .allMatch(task -> task.getCount().get() == FINISHED);
    }

    public Fetch mode(FMode mode) {
        this.mode = mode;
        return this;
    }

    public Fetch memoryCacheMode() {
        this.mode = FMode.D1SM;
        return this;
    }

    public Fetch single() {
        threadNum = 1;
        return this;
    }

    public Fetch multi(int threadNum) {
        if (threadNum > 0) {
            Fetch.threadNum = threadNum;
        }
        return this;
    }

    public Fetch reties(long reties) {
        if (reties > 0) {
            MAX_RETIES = reties;
        }
        return this;
    }

    protected Fetch sharingUnitFunc(Function<Long, Long> fuc) {
        if (fuc != null) {
            this.computeSharingUnitFunc = fuc;
        }
        return this;
    }

    public Fetch downloadHandler(DownloadHandler downloadHandler) {
        this.callback = downloadHandler;
        return this;
    }

    /** Registers a download of {@code address} into {@code location}; it starts when {@link #run()} is called. */
    public Fetch schedule(String address, String location) {
        final DownloadTask downloadTask = new DownloadTask(++ downloadNo, address, location);
        downloadMaps.put(downloadTask.getId(), downloadTask);
        return this;
    }

    /** Sends {@code task} asynchronously; the request permit is returned when it completes. */
    private void download(Task task) {
        client.sendAsync(buildRequest(task), HttpResponse.BodyHandlers.ofInputStream())
                .whenComplete((resp, err) -> {
                    try {
                        this.processor(resp, err, task);
                    } finally {
                        // Release unconditionally: a chained thenRun(...) would be skipped whenever
                        // the request completed exceptionally, leaking one permit per failure.
                        serviceSemaphore.release();
                    }
                });
    }


    private HttpRequest buildRequest(Task task) {
        HttpRequest.Builder builder = HttpRequest.newBuilder();
        builder.uri(URI.create(task.getAddress()));
        builder.timeout(Duration.ofMinutes(2));
        builder.header("User-Agent", "Chrome 8.90");
        builder.method(task.getMethod().name(), HttpRequest.BodyPublishers.noBody());
        if (task.isRange() && task instanceof ShardingTask) {
            ShardingTask shardingTask = (ShardingTask) task;
            // HTTP byte ranges are inclusive on both ends.
            builder.header("Range", "bytes=" + shardingTask.getStart() + "-" + shardingTask.getEnd());
        }
        return builder.build();
    }


    /**
     * Handles a completed request.
     * <ul>
     *   <li>HEAD responses are partitioned into shards.</li>
     *   <li>GET responses are streamed to the handler.</li>
     *   <li>Failures are retried up to {@link #MAX_RETIES} times.</li>
     * </ul>
     */
    private void processor(HttpResponse<InputStream> response, Throwable throwable, Task task) {
        if (throwable != null || response == null) {
            // The request itself failed (connect, timeout, I/O): there is no response to inspect.
            onFailure(throwable != null ? throwable : new IOException("No response for " + task.getAddress()), task);
            return;
        }

        final int status = response.statusCode();
        if (status < 200 || status >= 300) {
            // Close the unread body so the connection is released.
            closeQuietly(response.body());
            onFailure(new IOException("Unexpected HTTP status " + status + " for " + task.getAddress()), task);
            return;
        }

        // Dispatch on the task's own method: after a 303 redirect the final request may be a GET.
        if (task.getMethod() == HttpMethod.HEAD) {
            closeQuietly(response.body());
            pushShardingTask(task.getId(), response.headers(), response.request().uri().toString());
            return;
        }

        final DownloadTask downloadTask = downloadMaps.get(task.getId());
        boolean succeeded;
        try (final InputStream in = response.body()) {
            callback.data(in, downloadTask, (ShardingTask) task);
            succeeded = true;
        } catch (Exception e) {
            if (retry(e, task)) {
                return; // re-queued, so this shard is not finished yet
            }
            succeeded = false;
        }
        finishShard(downloadTask, succeeded);
    }

    /**
     * Retries {@code task}. Once retries are exhausted, marks its work as finished so the
     * completion monitor does not wait for it forever.
     */
    private void onFailure(Throwable e, Task task) {
        if (task instanceof DownloadTask) {
            // The HEAD probe failed: free the D1SM slot so other downloads (or this retry) can run.
            releaseDownloadSlot();
            if (!retry(e, task)) {
                // No shards were ever created, so nothing else will finish this download.
                ((DownloadTask) task).getCount().set(FINISHED);
            }
        } else if (!retry(e, task)) {
            finishShard(downloadMaps.get(task.getId()), false);
        }
    }

    /**
     * Records the end of one shard and, for the last shard of a download, reports completion.
     *
     * @param succeeded whether the shard's data was written; failed shards do not count towards
     *                  {@link DownloadTask#isSuccessful()}
     */
    private void finishShard(DownloadTask downloadTask, boolean succeeded) {
        if (succeeded) {
            downloadTask.getUnfinishedTaskNum().decrementAndGet();
        }
        // The decrement is atomic, so exactly one thread observes the transition to 0.
        if (downloadTask.getCount().decrementAndGet() == 0) {
            try {
                callback.completed(downloadTask);
            } finally {
                // Mark finished only after the callback, so the monitor cannot shut the pool down mid-callback.
                downloadTask.getCount().set(FINISHED);
                releaseDownloadSlot();
            }
        }
    }

    /** Returns the D1SM "one download at a time" permit; no-op in other modes. */
    private void releaseDownloadSlot() {
        if (mode == FMode.D1SM && downloadTaskSemaphore != null) {
            downloadTaskSemaphore.release();
        }
    }

    /** Closes a response body, ignoring errors (best effort). */
    private static void closeQuietly(InputStream in) {
        if (in == null) {
            return;
        }
        try {
            in.close();
        } catch (IOException ignored) {
            // best effort: nothing useful to do if closing fails
        }
    }

    /**
     * Re-queues {@code errorTask} if it has retries left, and notifies the handler of the failure.
     *
     * @return {@code true} if the task was re-queued
     */
    private boolean retry(Throwable e, Task errorTask) {
        boolean retryOk = false;
        final long reties = errorTask.getReties();
        if (reties < MAX_RETIES) {
            errorTask.incrReties();
            taskQueue.add(errorTask);
            retryOk = true;
        }
        callback.failure(e, errorTask);
        return retryOk;
    }


    /** Partitions a probed download into shards and queues them, or finishes it when there is nothing to fetch. */
    private void pushShardingTask(long id, HttpHeaders headers, String address) {
        final String acceptRanges = headers.firstValue("Accept-Ranges").orElse("");
        final boolean isRange = Objects.equals(acceptRanges, "bytes");
        final long contentLength = headers.firstValueAsLong("Content-Length").orElse(0L);

        final DownloadTask downloadTask = downloadMaps.get(id);
        final List<ShardingTask> shardingTasks = callback.partition(id, contentLength, isRange, address);

        if (shardingTasks == null || shardingTasks.isEmpty()) {
            // Nothing to fetch (e.g. no usable Content-Length). Without this the counter would stay
            // at 0: the monitor would never stop and, in D1SM mode, no later download could start.
            downloadTask.getCount().set(FINISHED);
            releaseDownloadSlot();
        } else {
            final String mime = lookupExtension(headers.firstValue("content-type").orElse(""));
            downloadTask.setLocation(this.checkingAndGetLocation(downloadTask.getLocation(), mime));
            downloadTask.getUnfinishedTaskNum().set(shardingTasks.size());
            downloadTask.getCount().set(shardingTasks.size());
            downloadTask.setCreatedTime(LocalDateTime.now());
            taskQueue.addAll(shardingTasks);
        }

        System.out.println("---------------- " + downloadTask + "----------------");
    }

    /**
     * Maps a Content-Type value to a file extension.
     *
     * <p>Tries the exact value first, then the value without parameters (e.g.
     * {@code "; charset=utf-8"}). Falls back to {@code ".none"}.
     */
    private String lookupExtension(String contentType) {
        final String exact = mimeMaps.get(contentType);
        if (exact != null) {
            return exact;
        }
        final String bare = contentType.split(";", 2)[0].trim();
        return mimeMaps.getOrDefault(bare, ".none");
    }


    /**
     * Resolves the final target path and makes sure its directory exists.
     *
     * <p>When {@code location} is an existing directory, a random file name with the {@code mime}
     * extension is generated inside it.
     *
     * @return the absolute target path
     */
    private String checkingAndGetLocation(String location, String mime) {
        Path path = Paths.get(location);
        try {
            if (Files.isDirectory(path)) {
                path = path.resolve(UUID.randomUUID() + mime);
            } else {
                // getParent() is null for a bare file name, so resolve against the working directory first.
                final Path parent = path.toAbsolutePath().getParent();
                if (parent != null) {
                    Files.createDirectories(parent);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return path.toAbsolutePath().toString();
    }


    /**
     * Loads {@code mime.txt} from the working directory.
     *
     * <p>Each line is {@code extension=mimeType}. The result maps the MIME type to the extension,
     * keeping the first entry for duplicate types.
     *
     * @return the table, or an empty map if the file cannot be read
     */
    protected static Map<String, String> loadingMime() {
        try {
            return Files.readAllLines(Paths.get("mime.txt"))
                    .stream()
                    .map(l -> l.split("="))
                    // Skip blank or malformed lines instead of failing with ArrayIndexOutOfBoundsException.
                    .filter(s -> s.length >= 2)
                    .collect(Collectors.toConcurrentMap(s -> s[1], s -> s[0], (n1, n2) -> n1));
        } catch (IOException e) {
            e.printStackTrace();
        }
        return Collections.emptyMap();
    }

    enum HttpMethod {
        GET, HEAD
    }

    /**
     * Scheduling modes.
     * <ul>
     *   <li>DMSM: all downloads run concurrently.</li>
     *   <li>D1SM: one download at a time, with its shards running in parallel.</li>
     * </ul>
     */
    enum FMode {
        DMSM, D1SM
    }

    @FunctionalInterface
    interface FRunner {
        void run(Task task) throws Exception;
    }
}